{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import transformers\n",
    "from sklearn import metrics\n",
    "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n",
    "from tqdm import tqdm\n",
    "from transformers import BertTokenizer, BertModel, BertConfig\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both input files are parsed as tab-separated text.\n",
    "train_df = pd.read_csv('../data/train_set.csv', sep='\\t')\n",
    "# The test frame gets a constant placeholder `label` column so that it can be\n",
    "# wrapped by the same Dataset class as the training data (which reads `.label`).\n",
    "test_df = pd.read_csv('../data/test_a.csv', sep='\\t').assign(label=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first rows of the training frame.\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the tokenizer from the local vocabulary file, then encode a sample string\n",
    "# so the cell output shows the structure returned by encode_plus.\n",
    "tokenizer = BertTokenizer.from_pretrained('../emb/bert-mini/vocab.txt')\n",
    "tokenizer.encode_plus(\"2967 6758 339 2021 1854\",\n",
    "        add_special_tokens=True,\n",
    "        max_length=20,\n",
    "        truncation=True)\n",
    "# token_type_ids: usually all 0 for the first sentence and all 1 for the second sentence.\n",
    "# attention_mask: 0 at padded positions, 1 at non-padded positions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataset(Dataset):\n",
    "\n",
    "    def __init__(self, dataframe, tokenizer, max_len):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.data = dataframe\n",
    "        self.comment_text = self.data.text\n",
    "        self.targets = self.data.label\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.comment_text)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        comment_text = self.comment_text[index]\n",
    "\n",
    "        inputs = self.tokenizer.encode_plus(\n",
    "            comment_text,\n",
    "            truncation=True,\n",
    "            add_special_tokens=True,\n",
    "            max_length=self.max_len,\n",
    "            pad_to_max_length=True,\n",
    "            return_token_type_ids=True\n",
    "        )\n",
    "        ids = inputs['input_ids']\n",
    "        mask = inputs['attention_mask']\n",
    "        token_type_ids = inputs[\"token_type_ids\"]\n",
    "\n",
    "        return {\n",
    "            'ids': torch.tensor(ids, dtype=torch.long),\n",
    "            'mask': torch.tensor(mask, dtype=torch.long),\n",
    "            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),\n",
    "            'targets': torch.tensor(self.targets[index], dtype=torch.float)\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out part of the labelled data for validation, then wrap each split in a CustomDataset.\n",
    "MAX_LEN = 256\n",
    "train_size = 0.8\n",
    "\n",
    "# Draw the training rows with a fixed seed; the rows that were not drawn become the\n",
    "# validation split. Both frames then get a fresh 0..n-1 index.\n",
    "train_dataset = train_df.sample(frac=train_size, random_state=7)\n",
    "valid_dataset = train_df.drop(train_dataset.index).reset_index(drop=True)\n",
    "train_dataset = train_dataset.reset_index(drop=True)\n",
    "\n",
    "# Report the shape of every frame.\n",
    "for split_message, split_frame in (\n",
    "    (\"FULL Dataset: {}\", train_df),\n",
    "    (\"TRAIN Dataset: {}\", train_dataset),\n",
    "    (\"VALID Dataset: {}\", valid_dataset),\n",
    "    (\"TEST Dataset: {}\", test_df),\n",
    "):\n",
    "    print(split_message.format(split_frame.shape))\n",
    "\n",
    "train_set, valid_set, test_set = (\n",
    "    CustomDataset(split_frame, tokenizer, MAX_LEN)\n",
    "    for split_frame in (train_dataset, valid_dataset, test_df)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch size used for each split.\n",
    "TRAIN_BATCH_SIZE = 32\n",
    "VALID_BATCH_SIZE = 16\n",
    "TEST_BATCH_SIZE = 16\n",
    "\n",
    "# DataLoader keyword arguments per split. The training and validation loaders shuffle;\n",
    "# the test loader keeps the row order of the test frame.\n",
    "train_params = dict(batch_size=TRAIN_BATCH_SIZE, shuffle=True)\n",
    "valid_params = dict(batch_size=VALID_BATCH_SIZE, shuffle=True)\n",
    "test_params = dict(batch_size=TEST_BATCH_SIZE, shuffle=False)\n",
    "\n",
    "train_loader = DataLoader(train_set, **train_params)\n",
    "valid_loader = DataLoader(valid_set, **valid_params)\n",
    "test_loader = DataLoader(test_set, **test_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Customized model: the position-0 vectors of the last two BERT layers are concatenated and\n",
    "# passed through a BiLSTM, a dense layer with ReLU, dropout and a final linear classifier.\n",
    "\n",
    "class BERTClass(torch.nn.Module):\n",
    "    \"\"\"BERT encoder followed by a BiLSTM + MLP classification head.\n",
    "\n",
    "    The encoder configuration and weights are read from '../emb/bert-mini/'.\n",
    "    forward() returns raw (un-normalized) class scores of shape [bs, 14].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(BERTClass, self).__init__()\n",
    "        self.config = BertConfig.from_pretrained('../emb/bert-mini/bert_config.json', output_hidden_states=True)\n",
    "        self.l1 = BertModel.from_pretrained('../emb/bert-mini/pytorch_model.bin', config=self.config)\n",
    "        # Two encoder vectors are concatenated, so the LSTM input is twice the encoder\n",
    "        # width. The width is read from the config instead of being hard-coded.\n",
    "        hidden_size = self.config.hidden_size\n",
    "        self.bilstm1 = torch.nn.LSTM(2 * hidden_size, 64, 1, bidirectional=True)\n",
    "        # The bidirectional LSTM emits 2 * 64 = 128 features per step.\n",
    "        self.l2 = torch.nn.Linear(128, 64)\n",
    "        self.a1 = torch.nn.ReLU()\n",
    "        self.l3 = torch.nn.Dropout(0.3)\n",
    "        self.l4 = torch.nn.Linear(64, 14)\n",
    "\n",
    "    def forward(self, ids, mask, token_type_ids):\n",
    "        \"\"\"Compute class scores for a batch.\n",
    "\n",
    "        Args:\n",
    "            ids: token ids passed to the encoder.\n",
    "            mask: attention mask passed to the encoder.\n",
    "            token_type_ids: segment ids passed to the encoder.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape [bs, 14] with un-normalized class scores.\n",
    "        \"\"\"\n",
    "        outputs = self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids)\n",
    "        # Index the result instead of tuple-unpacking it: integer indexing works for the\n",
    "        # plain tuple returned by older transformers releases as well as for the ModelOutput\n",
    "        # object returned by newer ones (unpacking a ModelOutput yields field names).\n",
    "        # outputs[0]: sequence output, outputs[1]: pooled output, outputs[2]: per-layer hidden states.\n",
    "        hidden_states = outputs[2]\n",
    "        # Position-0 vector of the last and second-to-last layer, each as a length-1\n",
    "        # sequence of shape [1, bs, hidden].\n",
    "        h12 = hidden_states[-1][:, 0].unsqueeze(0)\n",
    "        h11 = hidden_states[-2][:, 0].unsqueeze(0)\n",
    "        concat_hidden = torch.cat((h12, h11), 2)  # [1, bs, 2 * hidden]\n",
    "        x, _ = self.bilstm1(concat_hidden)  # [1, bs, 128]\n",
    "        x = self.l2(x.squeeze(0))  # drop the length-1 sequence axis -> [bs, 128] -> [bs, 64]\n",
    "        x = self.a1(x)\n",
    "        x = self.l3(x)\n",
    "        output = self.l4(x)\n",
    "        return output\n",
    "\n",
    "net = BERTClass()\n",
    "net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter settings: learning rate and number of training epochs.\n",
    "lr, num_epochs = 1e-5, 30\n",
    "criterion = torch.nn.CrossEntropyLoss()  # loss function\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)  # optimizer over all model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iter, net, device=torch.device('cpu')):\n",
    "    \"\"\"Evaluate accuracy and macro-F1 of a model on the given data set.\n",
    "\n",
    "    Args:\n",
    "        data_iter: iterable of batches; each batch is a dict holding the tensors\n",
    "            'ids', 'mask', 'token_type_ids' and 'targets'.\n",
    "        net: model called as net(ids, mask, token_type_ids) that returns one row of\n",
    "            class scores per sample. It is left in eval mode.\n",
    "        device: device the batches are moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (accuracy, macro_f1).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if data_iter yields no samples.\n",
    "    \"\"\"\n",
    "    correct, n = 0, 0\n",
    "    y_pred_, y_true_ = [], []\n",
    "    # Switch to eval mode once, before the loop, instead of on every batch.\n",
    "    net.eval()\n",
    "    # The forward pass itself must run under no_grad; otherwise an autograd graph is\n",
    "    # built for every validation batch and only wastes memory.\n",
    "    with torch.no_grad():\n",
    "        for data in tqdm(data_iter):\n",
    "            # If device is the GPU, copy the data to the GPU.\n",
    "            ids = data['ids'].to(device, dtype=torch.long)\n",
    "            mask = data['mask'].to(device, dtype=torch.long)\n",
    "            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n",
    "            targets = data['targets'].to(device, dtype=torch.long)\n",
    "            # Row-wise argmax turns scores into class ids, e.g.\n",
    "            # [[0.2, 0.4, 0.5, 0.6, 0.8], [0.1, 0.2, 0.4, 0.3, 0.1]] => [4, 2]\n",
    "            preds = torch.argmax(net(ids, mask, token_type_ids), dim=1)\n",
    "            correct += (preds == targets).sum().item()\n",
    "            y_pred_.extend(preds.cpu().tolist())\n",
    "            y_true_.extend(targets.cpu().tolist())\n",
    "            n += targets.shape[0]\n",
    "    if n == 0:\n",
    "        raise ValueError('data_iter produced no samples')\n",
    "    valid_f1 = metrics.f1_score(y_true_, y_pred_, average='macro')\n",
    "    return correct / n, valid_f1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch, train_iter, test_iter, criterion, num_epochs, optimizer, device,\n",
    "          save_path='model/best.pth'):\n",
    "    \"\"\"Train the model, validate after every epoch and checkpoint the best macro-F1.\n",
    "\n",
    "    Args:\n",
    "        epoch: kept under its original name for backward compatibility. The notebook\n",
    "            passes the model in this position (train(net, ...)); when it is a\n",
    "            torch.nn.Module that model is trained, otherwise the notebook-level\n",
    "            `net` is used, as the original code did.\n",
    "        train_iter: iterable of training batches (dicts with the tensors 'ids',\n",
    "            'mask', 'token_type_ids' and 'targets').\n",
    "        test_iter: iterable of validation batches, handed to evaluate_accuracy.\n",
    "        criterion: loss called as criterion(scores, targets) with long targets.\n",
    "        num_epochs: number of passes over train_iter.\n",
    "        optimizer: optimizer stepped once per batch.\n",
    "        device: device the model and the batches are moved to.\n",
    "        save_path: file the best state_dict is written to.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if train_iter yields no samples.\n",
    "    \"\"\"\n",
    "    model = epoch if isinstance(epoch, torch.nn.Module) else net\n",
    "    print('training on', device)\n",
    "    model.to(device)\n",
    "    best_test_f1 = 0\n",
    "    # Learning-rate schedule: halve the learning rate every 5 epochs.\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)\n",
    "    # Alternative (cosine annealing):\n",
    "    # torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=2e-06)\n",
    "\n",
    "    # Create the checkpoint directory up front so the first save cannot fail after a\n",
    "    # whole epoch of training.\n",
    "    save_dir = os.path.dirname(save_path)\n",
    "    if save_dir:\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "    # With 'mean' reduction the criterion returns a per-batch average, so it has to be\n",
    "    # weighted by the batch size before the total is divided by the sample count.\n",
    "    loss_is_mean = getattr(criterion, 'reduction', 'mean') == 'mean'\n",
    "\n",
    "    for epoch_idx in range(num_epochs):\n",
    "        train_l_sum, train_correct, n = 0.0, 0, 0\n",
    "        start = time.time()\n",
    "        y_pred, y_true = [], []\n",
    "        for data in tqdm(train_iter):\n",
    "            # evaluate_accuracy leaves the model in eval mode, so switch back.\n",
    "            model.train()\n",
    "            optimizer.zero_grad()\n",
    "            ids = data['ids'].to(device, dtype=torch.long)\n",
    "            mask = data['mask'].to(device, dtype=torch.long)\n",
    "            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n",
    "            targets = data['targets'].to(device, dtype=torch.long)\n",
    "            y_hat = model(ids, mask, token_type_ids)\n",
    "            loss = criterion(y_hat, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            with torch.no_grad():\n",
    "                batch_size = targets.shape[0]\n",
    "                preds = torch.argmax(y_hat, dim=1)\n",
    "                train_l_sum += loss.item() * (batch_size if loss_is_mean else 1)\n",
    "                train_correct += (preds == targets).sum().item()\n",
    "                y_pred.extend(preds.cpu().tolist())\n",
    "                y_true.extend(targets.cpu().tolist())\n",
    "                n += batch_size\n",
    "        if n == 0:\n",
    "            raise ValueError('train_iter produced no samples')\n",
    "        valid_acc, valid_f1 = evaluate_accuracy(test_iter, model, device)\n",
    "        train_acc = train_correct / n\n",
    "        train_f1 = metrics.f1_score(y_true, y_pred, average='macro')\n",
    "        print('epoch %d, loss %.4f, train acc %.3f, valid acc %.3f, '\n",
    "              'train f1 %.3f, valid f1 %.3f, time %.1f sec'\n",
    "              % (epoch_idx + 1, train_l_sum / n, train_acc, valid_acc,\n",
    "                 train_f1, valid_f1, time.time() - start))\n",
    "        if valid_f1 > best_test_f1:\n",
    "            print('find best! save at %s' % save_path)\n",
    "            best_test_f1 = valid_f1\n",
    "            torch.save(model.state_dict(), save_path)\n",
    "        scheduler.step()  # update the learning rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the training loop; the model is passed as the first positional argument.\n",
    "train(net,train_loader, valid_loader, criterion, num_epochs, optimizer, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_predict(net, test_iter):\n",
    "    \"\"\"Load the best checkpoint into `net` and predict a class id for every test sample.\n",
    "\n",
    "    Args:\n",
    "        net: model whose weights are replaced by the state dict stored in\n",
    "            'model/best.pth'; it is moved to the notebook-level `device`.\n",
    "        test_iter: iterable of batches (dicts with the tensors 'ids', 'mask' and\n",
    "            'token_type_ids').\n",
    "\n",
    "    Returns:\n",
    "        List with one predicted class index per sample, in the order the batches\n",
    "        are produced.\n",
    "    \"\"\"\n",
    "    preds_list = []\n",
    "    print('加载最优模型')\n",
    "    # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.\n",
    "    net.load_state_dict(torch.load('model/best.pth', map_location=device))\n",
    "    net = net.to(device)\n",
    "    # Inference must run in eval mode; a freshly built module starts in training mode,\n",
    "    # where dropout would randomly perturb the predictions.\n",
    "    net.eval()\n",
    "    print('inference测试集')\n",
    "    with torch.no_grad():\n",
    "        for data in tqdm(test_iter):\n",
    "            ids = data['ids'].to(device, dtype=torch.long)\n",
    "            mask = data['mask'].to(device, dtype=torch.long)\n",
    "            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n",
    "            batch_preds = net(ids, mask, token_type_ids).argmax(dim=1).cpu().numpy()\n",
    "            preds_list.extend(batch_preds)\n",
    "    return preds_list\n",
    "preds_list = model_predict(net, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill the sample-submission frame with the predictions and write it out.\n",
    "# The list is assigned by position, so preds_list must be in the same row order as the file.\n",
    "submission = pd.read_csv('../data/test_a_sample_submit.csv')\n",
    "submission['label'] = preds_list\n",
    "submission.to_csv('../output/submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
